import argparse

import pandas as pd
from llama_index.core.schema import TextNode

from iflytech_assistant.assistant.dataclasses import LianaiRagData
from iflytech_assistant.es import index

parser = argparse.ArgumentParser()

parser.add_argument("file", help="file path")
parser.add_argument("-i", "--index", help="index name")

args = parser.parse_args()

INDEX_NAME = args.index

# Rows whose click rate is below this threshold are discarded.
MIN_CLICKRATE = 0.08

df = pd.read_csv(args.file)
# Keep rows with clickrate >= MIN_CLICKRATE (NaN click rates compare False
# and are therefore dropped as well).
df = df[df["clickrate"] >= MIN_CLICKRATE]

# Only the two columns needed to build records are kept.
df = df[["exposure_content", "userinput"]]

# read_csv turns empty cells into NaN (a float); "\n".join would raise
# TypeError on such a value, so drop rows without exposure content first.
df = df.dropna(subset=["exposure_content"])

# Collapse all exposure contents for the same user input into a single
# newline-separated string, one row per distinct userinput.
df = df.groupby("userinput").agg(lambda x: "\n".join(x)).reset_index()

# Persona / tone labels; each one is written into the `tag` field of a
# LianaiRagData record. Order is preserved exactly.
tags = [
    "傲娇御姐", "元气满满", "活泼开朗", "阴阳怪气", "细腻暖男",
    "治愈天使", "直球出击", "糖果甜心", "浪漫诗人", "怼一怼",
    "猫系女友", "温柔宠溺", "彩虹夸夸", "花式道歉", "霸道总裁",
    "爹系男友", "贴心小奶狗", "文艺青年", "成熟大叔", "贴心蓝颜",
    "天才学霸", "阳光少年", "柔弱萝莉", "钓系美人", "情场高手",
    "放浪不羁", "温润如玉", "搞笑逗比", "恋爱脑", "知心姐姐",
    "刁蛮大小姐", "哄哄TA", "害羞宅男", "风流才子", "纯爱战神",
    "搞笑女", "古灵精怪", "欲擒故纵", "暧昧拉扯", "斩男茶语",
    "撒娇卖萌",
]

# Values for the `gender` field ("male", "female", plus a "default" entry,
# presumably a catch-all — confirm against the retrieval side).
genders = ["男", "女", "default"]

# Values for the `stage` field (relationship stage), again with a "default"
# entry that is presumably a catch-all — confirm.
stages = ["初识期", "热恋期", "表白期", "暧昧期", "稳定期", "default"]

# Build one TextNode for every (user input, tag, gender, stage) combination
# and send them all to the index in a single call.
nodes = []
for _, row in df.iterrows():
    user_input = row["userinput"]
    # exposure_content holds newline-separated examples; split once per row.
    examples = row["exposure_content"].split("\n")
    for tag in tags:
        for gender in genders:
            for stage in stages:
                data: LianaiRagData = LianaiRagData(
                    input=user_input,
                    gender=gender,
                    tag=tag,
                    stage=stage,
                    # Give each record its own list so no two records share
                    # (and could mutate) the same list object.
                    examples=list(examples),
                )
                # Convert and collect inside the innermost loop: doing it
                # outside the gender/stage loops keeps only the last pair.
                node: TextNode = data.to_text_node()
                nodes.append(node)
index(nodes, INDEX_NAME)
